import pandas as pd
import ollama
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm

# --- Configuration ---
# Where the input data is read from and where the augmented CSV is written.
INPUT_CSV: str = "data/raw/raw_bc_data.csv"
OUTPUT_CSV: str = "data/model_outputs/test_classification_results.csv"
# Column holding the text that every model classifies.
TEXT_COLUMN: str = "text"
# Gold-standard column; this script only checks that it is present.
GOLD_STANDARD_COLUMN: str = "bin"

# Model Names
# Hugging Face Hub checkpoint for the ConfliBERT sequence classifier.
CONFLIBERT_HF_MODEL: str = "eventdata-utd/conflibert-binary-classification"
# Ollama model tags queried for the LLM-based classifications.
GEMMA_OLLAMA_MODEL: str = "gemma2:9b-instruct-q4_K_M"
LLAMA_OLLAMA_MODEL: str = "llama3.1:8b-instruct-q4_K_M"

# --- 1. Hardware Acceleration Setup ---
def get_device():
    """Select the torch device to run inference on.

    Accelerators are probed in priority order: Apple Silicon MPS first, then
    NVIDIA CUDA. The first one available wins; otherwise the CPU is used.
    A status line describing the choice is printed.

    Returns:
        torch.device: the chosen device ("mps", "cuda" or "cpu").
    """
    accelerator_probes = (
        (torch.backends.mps.is_available, "mps",
         "✅ Apple Silicon MPS device found. Using MPS."),
        (torch.cuda.is_available, "cuda",
         "✅ NVIDIA CUDA device found. Using CUDA."),
    )
    for probe, device_type, banner in accelerator_probes:
        if probe():
            print(banner)
            return torch.device(device_type)

    print("⚠️ No hardware accelerator found. Using CPU.")
    return torch.device("cpu")

# --- 2. Hugging Face Model Inference Function ---
def classify_with_conflibert(texts, tokenizer, model, device, batch_size=16):
    """
    Classify texts with the ConfliBERT sequence-classification model in batches.

    Args:
        texts: Sliceable sequence of strings (e.g. a list). Each text is
            truncated to at most 512 tokens.
        tokenizer: Tokenizer matching ``model``. It is called with padding and
            truncation to build PyTorch tensors.
        model: Sequence-classification model whose ``config.id2label`` maps
            predicted class indices to label names. It is moved to ``device``
            and switched to eval mode.
        device: torch device to run inference on.
        batch_size: Number of texts per forward pass; must be >= 1.

    Returns:
        List of predicted label names, one per input text, in input order.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() raises for 0 and silently yields nothing for negative
        # steps, which would return fewer predictions than texts.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.to(device)
    model.eval()
    id2label = model.config.id2label
    predictions = []

    print("\nRunning classification with ConfliBERT...")
    for start in tqdm(range(0, len(texts), batch_size), desc="ConfliBERT Batches"):
        # list() so the tokenizer always receives a plain list of strings.
        batch_texts = list(texts[start:start + batch_size])
        inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        # One device->host copy per batch. Calling .item() on each element
        # would synchronise with the accelerator once per prediction.
        pred_ids = torch.argmax(logits, dim=-1).tolist()
        predictions.extend(id2label[idx] for idx in pred_ids)

    return predictions

# --- 3. Ollama Model Inference Function ---
PROMPT_TEMPLATE = """You are an expert text classifier. Your task is to classify the following text as either 'Conflict' or 'Not Conflict'.
'Conflict' refers to texts about war, violence, political unrest, or significant social tensions.
Your response MUST be a single JSON object and nothing else. Do not add explanations or markdown formatting.
Example format: {{"classification": "Conflict"}}
Text to classify: "{text}"
"""


def classify_text_with_ollama(client, model_name, text):
    """Classify one text with an Ollama chat model using JSON-mode output.

    The prompt built from PROMPT_TEMPLATE asks for a JSON object of the form
    {"classification": <label>}. Failures are deliberately non-fatal so that
    one bad row cannot abort a long batch run. Any problem is printed and
    reported as the sentinel 'Error'. Problems include:
    - a failed request;
    - invalid JSON;
    - a payload that is not a JSON object;
    - a missing, non-string or blank label.

    Args:
        client: Ollama client; only its ``chat`` method is used.
        model_name: Ollama model tag to query.
        text: Text to classify; inserted verbatim into the prompt.

    Returns:
        The model's label with surrounding whitespace stripped, or 'Error'.
        The label is expected to be 'Conflict' or 'Not Conflict', but this is
        not enforced here.
    """
    try:
        response = client.chat(
            model=model_name,
            messages=[{'role': 'user', 'content': PROMPT_TEMPLATE.format(text=text)}],
            format='json',
        )
        data = json.loads(response['message']['content'])
    except Exception as e:  # best-effort: network, API and decode errors alike
        print(f"\nError processing with {model_name}. Details: {e}")
        return 'Error'

    if not isinstance(data, dict):
        print(f"\nError processing with {model_name}. Details: expected a JSON object, got {type(data).__name__}")
        return 'Error'

    label = data.get('classification')
    # Reject missing keys and non-string values (e.g. numbers or nested
    # objects) rather than writing them into the results CSV.
    if not isinstance(label, str) or not label.strip():
        print(f"\nError processing with {model_name}. Details: invalid 'classification' value {label!r}")
        return 'Error'
    return label.strip()

# --- 4. Main Execution ---
def main():
    """Run ConfliBERT, Gemma and Llama over the input CSV and save the labels.

    Steps:
    1. Read INPUT_CSV and check that TEXT_COLUMN and GOLD_STANDARD_COLUMN
       exist.
    2. Classify every row's text with each model.
    3. Append the predictions as 'conflibert', 'gemma' and 'llama' columns.
    4. Write the result to OUTPUT_CSV.

    A missing input file or missing columns are reported, and the function
    returns early without raising.
    """
    from pathlib import Path

    device = get_device()

    print(f"Loading data from {INPUT_CSV}...")
    try:
        df = pd.read_csv(INPUT_CSV)
    except FileNotFoundError:
        print(f"❌ Error: The file {INPUT_CSV} was not found.")
        return

    if TEXT_COLUMN not in df.columns or GOLD_STANDARD_COLUMN not in df.columns:
        print(f"❌ Error: CSV must contain '{TEXT_COLUMN}' and '{GOLD_STANDARD_COLUMN}' columns.")
        print(f"Available columns: {df.columns.tolist()}")
        return

    # Create the output directory before the slow inference runs.
    # Otherwise a missing directory would only surface at to_csv time,
    # after all results had already been computed.
    Path(OUTPUT_CSV).parent.mkdir(parents=True, exist_ok=True)

    # Ensure the text column is string type and handle empty texts
    texts_to_process = df[TEXT_COLUMN].fillna('').astype(str).tolist()

    # --- Load Models ---
    print("\n--- Loading Models ---")
    print(f"Loading tokenizer for {CONFLIBERT_HF_MODEL}...")
    conflibert_tokenizer = AutoTokenizer.from_pretrained(CONFLIBERT_HF_MODEL)
    print(f"Loading model {CONFLIBERT_HF_MODEL}...")
    conflibert_model = AutoModelForSequenceClassification.from_pretrained(CONFLIBERT_HF_MODEL)

    print("Initializing Ollama client...")
    ollama_client = ollama.Client()

    # --- Run Classifications ---
    print("\n--- Starting All Classifications ---")
    conflibert_results = classify_with_conflibert(texts_to_process, conflibert_tokenizer, conflibert_model, device)

    # Run each Ollama model over the whole corpus before moving to the next.
    # Alternating models on every text makes Ollama swap them in and out of
    # memory whenever both cannot stay loaded at once, and that dominates runtime.
    ollama_models = (('gemma', GEMMA_OLLAMA_MODEL), ('llama', LLAMA_OLLAMA_MODEL))
    ollama_results = {}
    for column, model_name in ollama_models:
        ollama_results[column] = [
            classify_text_with_ollama(ollama_client, model_name, text)
            for text in tqdm(texts_to_process, desc=f"Ollama ({model_name})")
        ]

    # --- Add new columns to the DataFrame ---
    df['conflibert'] = conflibert_results
    df['gemma'] = ollama_results['gemma']
    df['llama'] = ollama_results['llama']

    # --- Save the results and finish ---
    print(f"\n✅ Classification complete. Saving updated data to {OUTPUT_CSV}...")
    df.to_csv(OUTPUT_CSV, index=False)
    print("Data saved successfully.")
    print("\nTo generate performance metrics, now run the 'generate_metrics.py' script.")


if __name__ == "__main__":
    # Script entry point: the pipeline runs only when executed directly, not on import.
    main()
